{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load IPython's autoreload extension and use mode 2, so imported modules are\n",
    "# reloaded before each cell runs and edits to local packages take effect\n",
    "# without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Configure XLA and MuJoCo rendering through environment variables before the\n",
    "# cells below import jax and mujoco_playground.\n",
    "TRITON_GEMM_FLAG = \"--xla_gpu_triton_gemm_any=True\"\n",
    "\n",
    "xla_flags = os.environ.get(\"XLA_FLAGS\", \"\")\n",
    "# Append the flag only when it is missing, so re-running this cell does not keep\n",
    "# growing XLA_FLAGS with duplicate entries.\n",
    "if TRITON_GEMM_FLAG not in xla_flags.split():\n",
    "  xla_flags = f\"{xla_flags} {TRITON_GEMM_FLAG}\".strip()\n",
    "os.environ[\"XLA_FLAGS\"] = xla_flags\n",
    "# Select the EGL backend for MuJoCo's OpenGL rendering.\n",
    "os.environ[\"MUJOCO_GL\"] = \"egl\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import functools\n",
    "import json\n",
    "from datetime import datetime\n",
    "\n",
    "import jax\n",
    "import matplotlib.pyplot as plt\n",
    "import mediapy as media\n",
    "from brax.training.agents.ppo import networks as ppo_networks\n",
    "from brax.training.agents.ppo import train as ppo\n",
    "from etils import epath\n",
    "from flax.training import orbax_utils\n",
    "from IPython.display import clear_output, display\n",
    "from orbax import checkpoint as ocp\n",
    "\n",
    "from mujoco_playground import wrapper, manipulation\n",
    "from mujoco_playground.config import manipulation_params\n",
    "\n",
    "# Enable JAX's persistent compilation cache so compiled programs can be reused\n",
    "# across kernel restarts. A minimum entry size of -1 and a minimum compile time\n",
    "# of 0 lift both thresholds, so every compilation is eligible for caching.\n",
    "for _cache_option, _cache_value in (\n",
    "    (\"jax_compilation_cache_dir\", \"/tmp/jax_cache\"),\n",
    "    (\"jax_persistent_cache_min_entry_size_bytes\", -1),\n",
    "    (\"jax_persistent_cache_min_compile_time_secs\", 0),\n",
    "):\n",
    "  jax.config.update(_cache_option, _cache_value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Name of the manipulation environment to train on. Later cells also use it to\n",
    "# look up the PPO hyperparameters and to name the experiment directory.\n",
    "env_name = \"PandaOpenCabinet\"\n",
    "# Default configuration registered for that environment.\n",
    "env_cfg = manipulation.get_default_config(env_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional label appended to the experiment name; None means no suffix.\n",
    "SUFFIX = None\n",
    "# NOTE(review): not read by any other visible cell - presumably a checkpoint\n",
    "# path to fine-tune from; confirm before relying on it.\n",
    "FINETUNE_PATH = None\n",
    "\n",
    "# Build a unique experiment name: \"<env_name>/<timestamp>\" plus \"-<SUFFIX>\"\n",
    "# when a suffix is set.\n",
    "now = datetime.now()\n",
    "timestamp = now.strftime(\"%Y%m%d-%H%M%S\")\n",
    "name_parts = [f\"{env_name}/{timestamp}\"]\n",
    "if SUFFIX is not None:\n",
    "  name_parts.append(f\"-{SUFFIX}\")\n",
    "exp_name = \"\".join(name_parts)\n",
    "print(f\"Experiment name: {exp_name}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PPO network constructor preset with four hidden policy layers of 32 units.\n",
    "# NOTE(review): no other visible cell references this factory.\n",
    "make_networks_factory = functools.partial(\n",
    "    ppo_networks.make_ppo_networks,\n",
    "    policy_hidden_layer_sizes=(32,) * 4,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Artifacts for this run go under the `checkpoints` directory (resolved to an\n",
    "# absolute path), in a sub-directory named after the experiment.\n",
    "ckpt_root = epath.Path(\"checkpoints\").resolve()\n",
    "ckpt_path = ckpt_root / exp_name\n",
    "ckpt_path.mkdir(parents=True, exist_ok=True)\n",
    "print(f\"Checkpoint path: {ckpt_path}\")\n",
    "\n",
    "# Store the environment config next to the checkpoints as indented JSON.\n",
    "with open(ckpt_path / \"config.json\", \"w\") as fp:\n",
    "  fp.write(json.dumps(env_cfg.to_dict(), indent=4))\n",
    "\n",
    "\n",
    "def policy_params_fn(current_step, make_policy, params):\n",
    "  \"\"\"Saves `params` as an Orbax checkpoint named after the current step.\n",
    "\n",
    "  The checkpoint is written to `ckpt_path / str(current_step)`; `force=True` is\n",
    "  passed so Orbax overwrites a checkpoint that already exists at that path.\n",
    "\n",
    "  Args:\n",
    "    current_step: Step count used as the checkpoint directory name.\n",
    "    make_policy: Unused here; presumably kept so the signature matches the\n",
    "      callback Brax invokes - confirm against the caller.\n",
    "    params: Parameter pytree to persist.\n",
    "  \"\"\"\n",
    "  ocp.PyTreeCheckpointer().save(\n",
    "      ckpt_path / f\"{current_step}\",\n",
    "      params,\n",
    "      force=True,\n",
    "      save_args=orbax_utils.save_args_from_target(params),\n",
    "  )\n",
    "\n",
    "# Brax PPO hyperparameters that mujoco_playground provides for this environment.\n",
    "_train_params = manipulation_params.brax_ppo_config(env_name)\n",
    "train_params = dict(_train_params)\n",
    "train_params[\"seed\"] = 1\n",
    "\n",
    "# The `network_factory` entry holds constructor settings (such as\n",
    "# `policy_hidden_layer_sizes`) rather than a callable, so pull it out of the\n",
    "# training kwargs and bind every setting onto `make_ppo_networks`. Forwarding\n",
    "# the whole entry keeps any other network settings in the config from being\n",
    "# dropped; a config without the entry falls back to the default networks.\n",
    "network_kwargs = dict(train_params.pop(\"network_factory\", {}))\n",
    "\n",
    "train_fn = functools.partial(\n",
    "    ppo.train,\n",
    "    **train_params,\n",
    "    network_factory=functools.partial(\n",
    "        ppo_networks.make_ppo_networks, **network_kwargs\n",
    "    ),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Histories filled in by `progress`: step counts, mean episode reward and its\n",
    "# standard deviation, one entry per callback invocation.\n",
    "x_data, y_data, y_dataerr = [], [], []\n",
    "# Wall-clock timestamps; the first entry marks when this cell started running.\n",
    "times = [datetime.now()]\n",
    "\n",
    "\n",
    "def progress(num_steps, metrics):\n",
    "  \"\"\"Records the latest evaluation metrics and redraws the live reward curve.\n",
    "\n",
    "  Appends to the module-level `times`, `x_data`, `y_data` and `y_dataerr`\n",
    "  histories, prints the throughput since the previous call, and replaces the\n",
    "  cell output with an updated error-bar plot.\n",
    "\n",
    "  Args:\n",
    "    num_steps: Number of environment steps completed so far.\n",
    "    metrics: Mapping that provides \"eval/episode_reward\" and\n",
    "      \"eval/episode_reward_std\".\n",
    "  \"\"\"\n",
    "  clear_output(wait=True)\n",
    "  times.append(datetime.now())\n",
    "  x_data.append(num_steps)\n",
    "  y_data.append(metrics[\"eval/episode_reward\"])\n",
    "  y_dataerr.append(metrics[\"eval/episode_reward_std\"])\n",
    "\n",
    "  # Performance: steps per second since the previous callback.\n",
    "  if len(x_data) >= 2:\n",
    "    steps_delta = x_data[-1] - x_data[-2]\n",
    "    seconds_delta = (times[-1] - times[-2]).total_seconds()\n",
    "    # Skip the report instead of dividing by zero if both calls share a timestamp.\n",
    "    if seconds_delta > 0:\n",
    "      fps = steps_delta / seconds_delta\n",
    "      print(f\"Training at {fps} FPS\")\n",
    "\n",
    "  # Plot. The whole history is redrawn on every call, so clear the figure first;\n",
    "  # otherwise each call stacks another copy of the curve on the same axes.\n",
    "  plt.clf()\n",
    "  plt.xlim([0, train_fn.keywords[\"num_timesteps\"] * 1.25])\n",
    "  plt.xlabel(\"# environment steps\")\n",
    "  plt.ylabel(\"reward per episode\")\n",
    "  plt.title(f\"y={y_data[-1]:.3f}\")\n",
    "  plt.errorbar(x_data, y_data, yerr=y_dataerr, color=\"blue\")\n",
    "\n",
    "  display(plt.gcf())\n",
    "\n",
    "\n",
    "env = manipulation.load(env_name, config=env_cfg)\n",
    "\n",
    "# The environment is wrapped for Brax here, so `wrap_env=False` is passed to\n",
    "# keep the trainer from wrapping it a second time.\n",
    "training_env = wrapper.wrap_for_brax_training(\n",
    "    env,\n",
    "    episode_length=env_cfg.episode_length,\n",
    "    action_repeat=env_cfg.action_repeat,\n",
    ")\n",
    "make_inference_fn, params, _ = train_fn(\n",
    "    environment=training_env,\n",
    "    wrap_env=False,\n",
    "    # Live reward plot, refreshed each time the trainer reports metrics.\n",
    "    progress_fn=progress,\n",
    "    # Save a checkpoint under `ckpt_path` whenever the trainer invokes this\n",
    "    # callback; without it the checkpoint directory only ever holds config.json.\n",
    "    policy_params_fn=policy_params_fn,\n",
    ")\n",
    "# `times[0]` is when the plotting cell started; `times[1]` is the first\n",
    "# `progress` callback.\n",
    "print(f\"time to jit: {times[1] - times[0]}\")\n",
    "print(f\"time to train: {times[-1] - times[1]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fresh, unwrapped environment for evaluating the trained policy.\n",
    "env = manipulation.load(env_name, config=env_cfg)\n",
    "\n",
    "# JIT-compile the environment transitions and the deterministic policy.\n",
    "jit_reset, jit_step = jax.jit(env.reset), jax.jit(env.step)\n",
    "inference_fn = make_inference_fn(params, deterministic=True)\n",
    "jit_inference_fn = jax.jit(inference_fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep every `render_every`-th stepped state for the video.\n",
    "# (Original note: policy is 50 FPS.)\n",
    "render_every = 2\n",
    "\n",
    "# Reset the environment with a key split off a fixed seed.\n",
    "key, key_reset = jax.random.split(jax.random.PRNGKey(5))\n",
    "state = jit_reset(key_reset)\n",
    "states = [state]\n",
    "\n",
    "# Roll the policy out for 125 steps, drawing a fresh action key on every step.\n",
    "for step_idx in range(125):\n",
    "  act_rng, key = jax.random.split(key)\n",
    "  ctrl, _ = jit_inference_fn(state.obs, act_rng)\n",
    "  state = jit_step(state, ctrl)\n",
    "  if step_idx % render_every != 0:\n",
    "    continue\n",
    "  states.append(state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render the collected states at 640x480 and play them back; the frame rate is\n",
    "# the control rate (1 / env.dt) divided by the frame-skip used when collecting.\n",
    "frames = env.render(states, height=480, width=640)\n",
    "playback_fps = 1.0 / env.dt / render_every\n",
    "media.show_video(frames, fps=playback_fps)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "franka_cartesian_pr",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
